//! Utility functionality

use std::collections::HashMap;
use std::fs::{self, DirBuilder, File};
use std::os::linux::fs::MetadataExt;
use std::os::unix::fs::DirBuilderExt;
use std::path::{Component, Path, PathBuf};
use std::time::Duration;

use libc::IFNAMSIZ;
use nix::sys::stat::Mode;
use nix::unistd::{Uid, User};
use oci_spec::runtime::{LinuxNamespaceType, Spec};

use crate::error::{LibcontainerError, MissingSpecError};
use crate::syscall::syscall::Syscall;
use crate::user_ns::UserNamespaceConfig;

/// Errors returned by the [`PathBufExt`] helpers.
#[derive(Debug, thiserror::Error)]
pub enum PathBufExtError {
    /// A relative path was given where an absolute one is required.
    #[error("relative path cannot be converted to the path in the container")]
    RelativePath,
    /// Removing the leading `/` from `path` failed.
    #[error("failed to strip prefix from {path:?}")]
    StripPrefix {
        path: PathBuf,
        source: std::path::StripPrefixError,
    },
    /// Canonicalizing an existing `path` failed.
    #[error("failed to canonicalize path {path:?}")]
    Canonicalize {
        path: PathBuf,
        source: std::io::Error,
    },
    /// The current working directory could not be determined.
    #[error("failed to get current directory")]
    CurrentDir { source: std::io::Error },
}

/// Extra path operations: re-rooting absolute paths under another directory
/// and resolving paths that may not exist on disk yet.
pub trait PathBufExt {
    /// Returns an absolute path with its leading `/` removed; errors on relative paths.
    fn as_relative(&self) -> Result<&Path, PathBufExtError>;
    /// Joins `p` onto `self`; an absolute `p` is placed under `self` instead of replacing it.
    fn join_safely<P: AsRef<Path>>(&self, p: P) -> Result<PathBuf, PathBufExtError>;
    /// Canonicalizes an existing path, or lexically normalizes one that does not exist.
    fn canonicalize_safely(&self) -> Result<PathBuf, PathBufExtError>;
    /// Lexically resolves `.` and `..` components without touching the filesystem.
    fn normalize(&self) -> PathBuf;
}

impl PathBufExt for Path {
    /// Returns this absolute path with its leading `/` removed.
    ///
    /// # Errors
    /// - [`PathBufExtError::RelativePath`] if `self` is relative.
    /// - [`PathBufExtError::StripPrefix`] if the `/` prefix cannot be stripped.
    fn as_relative(&self) -> Result<&Path, PathBufExtError> {
        if self.is_relative() {
            return Err(PathBufExtError::RelativePath);
        }
        self.strip_prefix("/")
            .map_err(|e| PathBufExtError::StripPrefix {
                path: self.to_path_buf(),
                source: e,
            })
    }

    /// Joins `path` onto `self`. An absolute `path` is placed under `self`
    /// instead of replacing it, which is what [`Path::join`] would do.
    ///
    /// Only the leading `/` is handled: `..` components in `path` are kept
    /// verbatim, so the result is not guaranteed to stay beneath `self`.
    ///
    /// # Errors
    /// [`PathBufExtError::StripPrefix`] (reporting `path`) if the leading `/`
    /// cannot be stripped.
    fn join_safely<P: AsRef<Path>>(&self, path: P) -> Result<PathBuf, PathBufExtError> {
        let path = path.as_ref();
        if path.is_relative() {
            return Ok(self.join(path));
        }
        // Delegating to `as_relative` makes a strip failure report the path
        // that was being stripped, not the base directory `self`.
        Ok(self.join(path.as_relative()?))
    }

    /// Canonicalizes existing and not existing paths.
    ///
    /// - An existing path is resolved with [`Path::canonicalize`], which
    ///   follows symlinks.
    /// - A missing relative path is anchored at the current working directory
    ///   and then lexically normalized.
    /// - A missing absolute path is only lexically normalized.
    ///
    /// # Errors
    /// - [`PathBufExtError::Canonicalize`] if canonicalizing an existing path fails.
    /// - [`PathBufExtError::CurrentDir`] if the working directory cannot be read.
    fn canonicalize_safely(&self) -> Result<PathBuf, PathBufExtError> {
        if self.exists() {
            return self
                .canonicalize()
                .map_err(|e| PathBufExtError::Canonicalize {
                    path: self.to_path_buf(),
                    source: e,
                });
        }

        if self.is_relative() {
            let cwd =
                std::env::current_dir().map_err(|e| PathBufExtError::CurrentDir { source: e })?;
            return Ok(cwd.join(self).normalize());
        }

        Ok(self.normalize())
    }

    /// Normalizes a path. In contrast to canonicalize the path does not need to exist.
    ///
    /// The resolution is purely lexical: symlinks are not followed. A `..`
    /// pops the previous component. A `..` with nothing to pop is dropped, so
    /// `../a` becomes `a` and `/..` stays `/`.
    // adapted from https://github.com/rust-lang/cargo/blob/fede83ccf973457de319ba6fa0e36ead454d2e20/src/cargo/util/paths.rs#L61
    fn normalize(&self) -> PathBuf {
        let mut components = self.components().peekable();
        let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {
            components.next();
            PathBuf::from(c.as_os_str())
        } else {
            PathBuf::new()
        };

        for component in components {
            match component {
                // A prefix can only be the first component, and it was consumed above.
                Component::Prefix(..) => unreachable!(),
                Component::RootDir => {
                    ret.push(component.as_os_str());
                }
                Component::CurDir => {}
                Component::ParentDir => {
                    ret.pop();
                }
                Component::Normal(c) => {
                    ret.push(c);
                }
            }
        }
        ret
    }
}

/// Parses `KEY=VALUE` strings (such as a process environment) into a map.
///
/// - Each entry is split at its first `=`. Everything after it, including any
///   further `=`, is the value.
/// - An entry with no `=` maps to an empty value.
/// - When a key appears more than once, the last occurrence wins.
pub fn parse_env(envs: &[String]) -> HashMap<String, String> {
    envs.iter()
        .map(|entry| {
            // `split_once` finds the first `=` without allocating an
            // intermediate Vec and re-joining the tail.
            let (key, value) = entry.split_once('=').unwrap_or((entry.as_str(), ""));
            (key.to_owned(), value.to_owned())
        })
        .collect()
}

/// Looks up the user owning `uid`.
///
/// Returns `None` both when no such user exists and when the lookup itself
/// fails; errors are deliberately discarded.
pub fn get_unix_user(uid: Uid) -> Option<User> {
    User::from_uid(uid).ok().flatten()
}

/// Returns the home directory of the user with UID `uid`.
///
/// Yields `None` if no such user exists or the lookup fails, since
/// [`get_unix_user`] reports both cases as `None`.
pub fn get_user_home(uid: u32) -> Option<PathBuf> {
    get_unix_user(Uid::from_raw(uid)).map(|user| user.dir)
}

/// Picks the cgroup path for a container.
///
/// An explicitly configured `cgroups_path` is returned as-is. Otherwise the
/// default `:youki:<container_id>` is generated.
pub fn get_cgroup_path(cgroups_path: &Option<PathBuf>, container_id: &str) -> PathBuf {
    if let Some(explicit) = cgroups_path {
        return explicit.clone();
    }
    PathBuf::from(format!(":youki:{container_id}"))
}

pub fn write_file<P: AsRef<Path>, C: AsRef<[u8]>>(
    path: P,
    contents: C,
) -> Result<(), std::io::Error> {
    fs::write(path.as_ref(), contents).map_err(|err| {
        tracing::error!(path = ?path.as_ref(), ?err, "failed to write file");
        err
    })?;

    Ok(())
}

pub fn create_dir_all<P: AsRef<Path>>(path: P) -> Result<(), std::io::Error> {
    fs::create_dir_all(path.as_ref()).map_err(|err| {
        tracing::error!(path = ?path.as_ref(), ?err, "failed to create directory");
        err
    })?;
    Ok(())
}

pub fn open<P: AsRef<Path>>(path: P) -> Result<File, std::io::Error> {
    File::open(path.as_ref()).map_err(|err| {
        tracing::error!(path = ?path.as_ref(), ?err, "failed to open file");
        err
    })
}

/// Errors returned by [`create_dir_all_with_mode`].
#[derive(Debug, thiserror::Error)]
pub enum MkdirWithModeError {
    /// Creating the directory or reading its metadata failed.
    #[error("IO error")]
    Io(#[from] std::io::Error),
    /// The path exists but is not a directory, has a different owner, or
    /// lacks some of the requested permission bits.
    #[error("metadata doesn't match the expected attributes")]
    MetadataMismatch,
}

/// Creates `path` (and any missing parents) with permission bits `mode`, then
/// verifies the result.
///
/// Creation is skipped when `path` already exists. In every case the function
/// then checks three things about `path`:
/// - it is a directory;
/// - it is owned by `owner`;
/// - its mode contains every bit of `mode`.
///
/// Ownership is only checked, never changed. The bits handed to `mkdir` are
/// filtered by the process umask, so asking for bits the umask clears makes
/// the check fail.
///
/// # Errors
/// - [`MkdirWithModeError::Io`] if creating the directory or reading its
///   metadata fails.
/// - [`MkdirWithModeError::MetadataMismatch`] if the verification fails.
///
/// # Example
/// ``` no_run
/// use libcontainer::utils::create_dir_all_with_mode;
/// use nix::sys::stat::Mode;
/// use std::path::Path;
///
/// let path = Path::new("/tmp/youki");
/// create_dir_all_with_mode(&path, 1000, Mode::S_IRWXU).unwrap();
/// assert!(path.exists())
/// ```
pub fn create_dir_all_with_mode<P: AsRef<Path>>(
    path: P,
    owner: u32,
    mode: Mode,
) -> Result<(), MkdirWithModeError> {
    let dir = path.as_ref();
    let wanted = mode.bits();

    if !dir.exists() {
        DirBuilder::new().mode(wanted).recursive(true).create(dir)?;
    }

    let meta = dir.metadata()?;
    let is_dir = meta.is_dir();
    let owned_by_owner = meta.st_uid() == owner;
    let has_all_bits = meta.st_mode() & wanted == wanted;

    if !(is_dir && owned_by_owner && has_all_bits) {
        return Err(MkdirWithModeError::MetadataMismatch);
    }
    Ok(())
}

/// Reports whether this process appears to run inside a user namespace other
/// than the initial one.
///
/// The check reads `/proc/self/uid_map` and returns `false` when it contains
/// `4294967295`. That is the length of the identity mapping the initial user
/// namespace uses (see user_namespaces(7)).
///
/// # Errors
/// Any I/O error from reading `/proc/self/uid_map`.
pub fn is_in_new_userns() -> Result<bool, std::io::Error> {
    const FULL_RANGE_LEN: &str = "4294967295";
    std::fs::read_to_string("/proc/self/uid_map").map(|uid_map| !uid_map.contains(FULL_RANGE_LEN))
}

/// Checks if rootless mode needs to be used.
///
/// Rootless mode is required when the effective UID is not root. It is also
/// required when running as root inside a new user namespace (see
/// [`is_in_new_userns`]).
///
/// # Errors
/// Any I/O error from the user-namespace check. That check only runs when
/// the effective UID is root.
pub fn rootless_required(syscall: &dyn Syscall) -> Result<bool, std::io::Error> {
    if syscall.get_euid().is_root() {
        is_in_new_userns()
    } else {
        Ok(true)
    }
}

/// Checks that `spec` is valid for the current user-namespace setup.
///
/// When rootless operation is required, a user namespace must come from one
/// of two places. Either the spec specifies a new one, or this process
/// already runs inside a new user namespace (this is how podman launches
/// youki). If neither holds, the spec is rejected.
///
/// # Errors
/// - Any error produced while building the `UserNamespaceConfig` from `spec`.
/// - [`LibcontainerError::OtherIO`] if the user-namespace checks fail to read
///   `/proc/self/uid_map`.
/// - [`LibcontainerError::NoUserNamespace`] if rootless is required but no
///   user namespace is available.
pub fn validate_spec_for_new_user_ns(
    spec: &Spec,
    syscall: &dyn Syscall,
) -> Result<(), LibcontainerError> {
    let userns_config = UserNamespaceConfig::new(spec)?;
    let already_in_userns = is_in_new_userns().map_err(LibcontainerError::OtherIO)?;
    let needs_rootless = rootless_required(syscall).map_err(LibcontainerError::OtherIO)?;

    let userns_available = already_in_userns || userns_config.is_some();
    if needs_rootless && !userns_available {
        Err(LibcontainerError::NoUserNamespace)
    } else {
        Ok(())
    }
}

/// Runs `op` up to `attempts` times, sleeping `delay` between tries.
///
/// After a failure, `op` is retried only if attempts remain and
/// `policy(&err)` returns `true`. Otherwise that error is returned
/// immediately. The first `Ok` is returned as soon as it is produced.
///
/// # Panics
/// Panics if `attempts` is `0`.
pub fn retry<F, T, E, P>(mut op: F, attempts: u32, delay: Duration, policy: P) -> Result<T, E>
where
    F: FnMut() -> Result<T, E>,
    P: Fn(&E) -> bool,
{
    assert!(
        attempts != 0,
        "retry called with 0 attempts. Minimum attempts is 1."
    );

    let mut remaining = attempts;
    loop {
        // Count this try up front so `remaining` is the number of retries left.
        remaining -= 1;
        match op() {
            Ok(value) => return Ok(value),
            Err(err) if remaining > 0 && policy(&err) => std::thread::sleep(delay),
            Err(err) => return Err(err),
        }
    }
}

/// Errors returned by [`validate_spec_for_net_devices`].
#[derive(Debug, thiserror::Error)]
pub enum NetDevicesError {
    /// Net devices are configured but the spec has no network namespace.
    #[error("unable to move network devices without a NET namespace")]
    NoNetNamespace,
    /// Net devices are configured but the runtime must run rootless.
    #[error("network devices are not supported in rootless containers")]
    RootlessNotSupported,
    /// A host or container device name failed `dev_valid_name`; carries the offending name.
    #[error("invalid network device name: {0}")]
    InvalidDeviceName(String),
    /// I/O failure while determining whether rootless mode is required.
    #[error(transparent)]
    IO(#[from] std::io::Error),
    /// A required section of the spec is missing.
    #[error(transparent)]
    Spec(#[from] MissingSpecError),
}

/// Checks that the `linux.netDevices` section of `spec` can be honoured.
///
/// If no net devices are configured (the field is absent), the spec is
/// accepted. Otherwise, in this order:
/// - the spec must request a network namespace;
/// - rootless mode must not be required;
/// - every device's host-side name (the map key) must pass `dev_valid_name`;
/// - so must its optional in-container `name`.
///
/// # Errors
/// - [`NetDevicesError::Spec`] if the spec has no `linux` section.
/// - [`NetDevicesError::NoNetNamespace`] if no network namespace is requested.
/// - [`NetDevicesError::IO`] if the rootless check fails.
/// - [`NetDevicesError::RootlessNotSupported`] if rootless mode is required.
/// - [`NetDevicesError::InvalidDeviceName`] with the first invalid name found.
pub fn validate_spec_for_net_devices(
    spec: &Spec,
    syscall: &dyn Syscall,
) -> Result<(), NetDevicesError> {
    let Some(linux) = spec.linux() else {
        return Err(NetDevicesError::Spec(MissingSpecError::Linux));
    };
    let Some(devices) = linux.net_devices() else {
        return Ok(());
    };

    let wants_net_namespace = linux.namespaces().as_ref().is_some_and(|namespaces| {
        namespaces
            .iter()
            .any(|ns| ns.typ() == LinuxNamespaceType::Network)
    });
    if !wants_net_namespace {
        return Err(NetDevicesError::NoNetNamespace);
    }

    if rootless_required(syscall)? {
        return Err(NetDevicesError::RootlessNotSupported);
    }

    for (host_name, device) in devices {
        // Host-side name first, then the optional in-container rename.
        let invalid = std::iter::once(host_name)
            .chain(device.name())
            .find(|name| !dev_valid_name(name));
        if let Some(name) = invalid {
            return Err(NetDevicesError::InvalidDeviceName(name.clone()));
        }
    }

    Ok(())
}

/// Reports whether `name` is acceptable as a Linux network interface name.
///
/// Mirrors the kernel's `dev_valid_name()`:
/// <https://elixir.bootlin.com/linux/v6.12/source/net/core/dev.c#L1066>
///
/// The kernel rejects names with `strnlen(name, IFNAMSIZ) == IFNAMSIZ`,
/// because `IFNAMSIZ` includes the trailing NUL. The longest valid name is
/// therefore `IFNAMSIZ - 1` bytes.
fn dev_valid_name(name: &str) -> bool {
    if name.is_empty() || name.len() >= IFNAMSIZ {
        return false;
    }
    if name == "." || name == ".." {
        return false;
    }

    // The kernel reads the name as a NUL-terminated C string, so an embedded
    // NUL would silently truncate it; reject it alongside '/', ':' and spaces.
    !name
        .chars()
        .any(|c| c == '/' || c == ':' || c == '\0' || c.is_whitespace())
}

#[cfg(test)]
mod tests {
    use core::panic;

    use anyhow::{Result, bail};
    use nix::unistd::Gid;
    use oci_spec::runtime::{LinuxBuilder, LinuxNamespaceBuilder, LinuxNetDevice, SpecBuilder};
    use serial_test::serial;

    use super::*;
    use crate::syscall::syscall::create_syscall;
    use crate::test_utils;

    #[test]
    pub fn test_get_unix_user() {
        let user = get_unix_user(Uid::from_raw(0));
        assert_eq!(user.unwrap().name, "root");

        // for a non-exist UID
        let user = get_unix_user(Uid::from_raw(1000000000));
        assert!(user.is_none());
    }

    #[test]
    pub fn test_get_user_home() {
        let dir = get_user_home(0);
        assert_eq!(dir.unwrap().to_str().unwrap(), "/root");

        // for a non-exist UID
        let dir = get_user_home(1000000000);
        assert!(dir.is_none());
    }

    #[test]
    fn test_get_cgroup_path() {
        let cid = "sample_container_id";
        assert_eq!(
            get_cgroup_path(&None, cid),
            PathBuf::from(":youki:sample_container_id")
        );
        assert_eq!(
            get_cgroup_path(&Some(PathBuf::from("/youki")), cid),
            PathBuf::from("/youki")
        );
    }

    #[test]
    fn test_parse_env() -> Result<()> {
        let key = "key".into();
        let value = "value".into();
        let env_input = vec![format!("{key}={value}")];
        let env_output = parse_env(&env_input);
        assert_eq!(
            env_output.len(),
            1,
            "There should be exactly one entry inside"
        );
        assert_eq!(env_output.get_key_value(&key), Some((&key, &value)));

        Ok(())
    }

    #[test]
    fn test_create_dir_all_with_mode() -> Result<()> {
        {
            let temdir = tempfile::tempdir()?;
            let path = temdir.path().join("test");
            let syscall = create_syscall();
            let uid = syscall.get_uid().as_raw();
            let mode = Mode::S_IRWXU;
            create_dir_all_with_mode(&path, uid, mode)?;
            let metadata = path.metadata()?;
            assert!(path.is_dir());
            assert_eq!(metadata.st_uid(), uid);
            assert_eq!(metadata.st_mode() & mode.bits(), mode.bits());
        }
        {
            let temdir = tempfile::tempdir()?;
            let path = temdir.path().join("test");
            let mode = Mode::S_IRWXU;
            std::fs::create_dir(&path)?;
            assert!(path.is_dir());
            match create_dir_all_with_mode(&path, 8899, mode) {
                Err(MkdirWithModeError::MetadataMismatch) => {}
                _ => bail!("should return MetadataMismatch"),
            }
        }
        Ok(())
    }

    #[test]
    fn test_io() -> Result<()> {
        {
            let tempdir = tempfile::tempdir()?;
            let path = tempdir.path().join("test");
            write_file(&path, "test".as_bytes())?;
            open(&path)?;
            assert!(create_dir_all(path).is_err());
        }
        {
            let tempdir = tempfile::tempdir()?;
            let path = tempdir.path().join("test");
            create_dir_all(&path)?;
            assert!(write_file(&path, "test".as_bytes()).is_err());
        }
        {
            let tempdir = tempfile::tempdir()?;
            let path = tempdir.path().join("test");
            assert!(open(&path).is_err());
            create_dir_all(&path)?;
            assert!(path.is_dir())
        }

        Ok(())
    }

    // the following test is marked as serial because
    // we are doing unshare of user ns and fork, so better to run in serial,
    #[test]
    #[serial]
    fn test_userns_spec_validation() -> Result<(), test_utils::TestError> {
        use nix::sched::{CloneFlags, unshare};
        let syscall = create_syscall();
        // default rootful spec
        let rootful_spec = Spec::default();
        // as we are not in a user ns, and spec does not have user ns
        // we should get error here
        assert!(validate_spec_for_new_user_ns(&rootful_spec, &*syscall).is_err());

        let rootless_spec = Spec::rootless(1000, 1000);
        // because the spec contains user ns info, we should not get error
        assert!(validate_spec_for_new_user_ns(&rootless_spec, &*syscall).is_ok());

        test_utils::test_in_child_process(|| {
            unshare(CloneFlags::CLONE_NEWUSER).unwrap();
            // here we are in a new user namespace
            let rootful_spec = Spec::default();
            let syscall = create_syscall();
            // because we are already in a new user ns, it is fine if spec
            // does not have user ns, and because the test is running as
            // non root
            assert!(validate_spec_for_new_user_ns(&rootful_spec, &*syscall).is_ok());

            let rootless_spec = Spec::rootless(1000, 1000);
            // following should succeed irrespective if we're in user ns or not
            assert!(validate_spec_for_new_user_ns(&rootless_spec, &*syscall).is_ok());
            Ok(())
        })
    }

    #[test]
    fn test_dev_valid_name() {
        assert!(!dev_valid_name(""));

        // IFNAMSIZ includes the NUL terminator, so IFNAMSIZ bytes is already too long.
        let long_name = "a".repeat(IFNAMSIZ + 1);
        assert!(!dev_valid_name(&long_name));
        let boundary_name = "a".repeat(IFNAMSIZ);
        assert!(!dev_valid_name(&boundary_name));

        let valid_name = "a".repeat(IFNAMSIZ - 1);
        assert!(dev_valid_name(&valid_name));

        assert!(!dev_valid_name("."));
        assert!(!dev_valid_name(".."));

        assert!(!dev_valid_name("/: "));
        assert!(!dev_valid_name("eth0/: "));
        assert!(!dev_valid_name("eth\0"));

        assert!(dev_valid_name("eth0"));
        assert!(dev_valid_name("veth123"));
        assert!(dev_valid_name("abc.def"));
    }

    fn build_spec_with_ns_and_devices(include_net_ns: bool, devices: Vec<(&str, &str)>) -> Spec {
        let mut namespaces = vec![];
        if include_net_ns {
            namespaces.push(
                LinuxNamespaceBuilder::default()
                    .typ(LinuxNamespaceType::Network)
                    .path(PathBuf::from("/dev/net"))
                    .build()
                    .unwrap(),
            );
        }

        let net_devices: HashMap<String, LinuxNetDevice> = devices
            .into_iter()
            .map(|(key, val)| {
                (
                    key.into(),
                    LinuxNetDevice::default().set_name(Some(val.into())).clone(),
                )
            })
            .collect();
        let linux = LinuxBuilder::default()
            .namespaces(namespaces)
            .net_devices(net_devices)
            .build()
            .unwrap();

        SpecBuilder::default().linux(linux).build().unwrap()
    }

    #[test]
    fn test_net_devices_none() {
        let spec = Spec::default();
        let syscall = create_syscall();
        syscall.set_id(Uid::from_raw(0), Gid::from_raw(0)).unwrap();
        let result = validate_spec_for_net_devices(&spec, &*syscall);
        assert!(result.is_ok());
    }

    #[test]
    fn test_missing_net_namespace() {
        let spec = build_spec_with_ns_and_devices(false, vec![]);
        let syscall = create_syscall();
        let err = validate_spec_for_net_devices(&spec, &*syscall).unwrap_err();
        assert!(matches!(err, NetDevicesError::NoNetNamespace));
    }

    #[test]
    fn test_invalid_device_name() {
        let spec = build_spec_with_ns_and_devices(true, vec![("eth0", "/:invalid")]);
        let syscall = create_syscall();
        syscall.set_id(Uid::from_raw(0), Gid::from_raw(0)).unwrap();
        let err = validate_spec_for_net_devices(&spec, &*syscall).unwrap_err();
        if let NetDevicesError::InvalidDeviceName(name) = err {
            assert_eq!(name, "/:invalid");
        } else {
            panic!("Expected InvalidDeviceName error");
        }
    }

    #[test]
    fn test_valid_config() {
        let spec = build_spec_with_ns_and_devices(true, vec![("eth0", "eth0_container")]);
        let syscall = create_syscall();
        syscall.set_id(Uid::from_raw(0), Gid::from_raw(0)).unwrap();
        let result = validate_spec_for_net_devices(&spec, &*syscall);
        assert!(result.is_ok());
    }
}
